# pylint: disable=unused-import,abstract-method,unused-argument,no-member

##########################################################################################
#### Imports
##########################################################################################

from flask import Markup
from statistics import mean
import random
import re
from typing import Union, List
import time
from dallinger import db

import psynet.experiment

from psynet.timeline import get_template
from psynet.field import claim_field
from psynet.participant import Participant, get_participant
from psynet.timeline import (
    Page,
    Timeline,
    PageMaker,
    CodeBlock,
    while_loop,
    conditional,
    switch,
    FailedValidation
)
from psynet.page import (
    InfoPage,
    SuccessfulEndPage,
    SliderPage,
    NAFCPage,
    NumberInputPage,
    VolumeCalibration
)
from psynet.headphone import headphone_check
from psynet.trial.chain import ChainNetwork
from psynet.trial.gibbs import (
    GibbsNetwork, GibbsTrial, GibbsNode, GibbsSource, GibbsTrialMaker
)
from timeline_parts import (
    instructions_training,
    instructions_experiment,
    final_questionnaire
)

import logging
import os

# Configure the root logger so INFO-level (and above) records are emitted.
logging.basicConfig(level=logging.INFO)
# NOTE(review): the logger is named after this file's path (__file__),
# not the module name (__name__).
logger = logging.getLogger(__file__)

# import rpdb

# Candidate target words; CustomNetwork.make_definition picks one per network
# through balance_across_networks().
TARGETS = ["pleasant"]
# Slider ids. Their count sets CustomNetwork.vector_length, and the list is
# stored as the "sliders" entry of every trial definition.
INTERVALS = ["interval1", "interval2"]
# Upper bound: slider max_value and upper limit of CustomNetwork.random_sample.
MIDI_RANGE = 11
# Lower bound: slider min_value and lower limit of CustomNetwork.random_sample.
MIN_INTERVAL = 0.5
# Assigned to Exp.initial_recruitment_size.
INITIAL_RECRUITMENT_SIZE = 20
# Passed as time_estimate_per_trial to both trial makers
# (presumably seconds -- TODO confirm against PsyNet).
TIME_ESTIMATE_PER_TRIAL = 20


def get_template(name):
    """Return the contents of ``templates/<name>`` as a string.

    The path is resolved relative to the current working directory and the
    file is decoded as UTF-8.

    Note: this module-level definition rebinds the ``get_template`` name that
    is imported from ``psynet.timeline`` earlier in the file.

    Args:
        name: File name of the template inside the ``templates`` directory.

    Returns:
        The full text of the template file.

    Raises:
        TypeError: If ``name`` is not a ``str``.
        OSError: If the file cannot be opened (for example, it is missing).
    """
    # Explicit check instead of ``assert`` so the validation is not stripped
    # when Python runs with -O.
    if not isinstance(name, str):
        raise TypeError(
            f"template name must be a str, got {type(name).__name__}"
        )
    template_path = os.path.join('templates', name)
    with open(template_path, encoding='utf-8') as fp:
        return fp.read()


class ChordSliderPage(SliderPage):
    """Slider page bound to one entry of ``INTERVALS``.

    The entry at ``selected_idx`` becomes the visible slider. The starting
    values of all other entries are passed to the ``chord-slider.html``
    template as ``hidden_inputs`` (a mapping of slider id to value).
    """

    def __init__(
            self,
            label: str,
            prompt: Union[str, Markup],
            selected_idx: int,
            starting_values: List[int],
            f0: float,
            sliders: List[str],
            reverse_scale: bool,
            time_estimate=None,
            **kwargs
    ):
        """Set up the slider for the selected interval.

        Args:
            label: Page label, forwarded to ``SliderPage``.
            prompt: Text shown to the participant; also kept for ``metadata``.
            selected_idx: Index into ``INTERVALS`` of the adjustable slider.
            starting_values: One starting value per entry of ``INTERVALS``.
            f0: Exposed to the page's JavaScript as ``f0``.
            sliders: Exposed to the page's JavaScript as ``sliders``.
            reverse_scale: Forwarded to ``SliderPage``.
            time_estimate: Forwarded to ``SliderPage``.
            **kwargs: Extra ``SliderPage`` arguments. A ``template_arg``
                entry, if given, is overwritten.
        """
        assert 0 <= selected_idx < len(INTERVALS)

        self.prompt = prompt
        self.selected_idx = selected_idx
        self.starting_values = starting_values

        # Indices of every interval except the one shown as a slider.
        fixed_idxs = list(range(len(INTERVALS)))
        fixed_idxs.remove(selected_idx)
        kwargs['template_arg'] = {
            'hidden_inputs': {
                INTERVALS[i]: starting_values[i] for i in fixed_idxs
            },
        }

        super().__init__(
            label=label,
            prompt=prompt,
            time_estimate=time_estimate,
            template_str=get_template("chord-slider.html"),
            start_value=starting_values[selected_idx],
            min_value=MIN_INTERVAL,
            max_value=MIDI_RANGE,
            slider_id=INTERVALS[selected_idx],
            reverse_scale=reverse_scale,
            js_vars={"f0": f0, "sliders": sliders},
            **kwargs
        )

    def metadata(self, **kwargs):
        """Return the prompt, selected index and starting values as a dict.

        ``kwargs`` are accepted but not used.
        """
        page_info = {
            "prompt": self.prompt,
            "selected_idx": self.selected_idx,
            "starting_values": self.starting_values,
        }
        return page_info


class CustomNetwork(GibbsNetwork):
    """Gibbs network with one vector dimension per entry of ``INTERVALS``."""

    __mapper_args__ = {"polymorphic_identity": "custom_network"}

    vector_length = len(INTERVALS)

    # Minimal example of an async_post_grow_network function
    run_async_post_grow_network = False

    def random_sample(self, i):
        """Return a uniform random value between MIN_INTERVAL and MIDI_RANGE.

        The argument ``i`` is accepted but not used.
        """
        assert MIN_INTERVAL < MIDI_RANGE
        span = MIDI_RANGE - MIN_INTERVAL
        return MIN_INTERVAL + span * random.random()

    def make_definition(self):
        """Return the network definition: a ``target`` word from ``TARGETS``.

        The word is selected by ``balance_across_networks``.
        """
        target = self.balance_across_networks(TARGETS)
        return {"target": target}

    def async_post_grow_network(self):
        """Log that the custom post-grow hook ran for this network."""
        logger.info("Running custom async_post_grow_network function (network id = %i)", self.id)

class CustomTrial(GibbsTrial):
    """Gibbs trial presented to the participant as a ``ChordSliderPage``."""

    __mapper_args__ = {"polymorphic_identity": "custom_trial"}

    # If True, then the starting value for the free parameter is resampled
    # on each trial.
    resample_free_parameter = True
    # Both values below are forwarded to ChordSliderPage in show_trial.
    minimal_interactions = 3
    minimal_time = 3.0

    # Minimal example of an async_post_trial function
    run_async_post_trial = False

    def show_trial(self, experiment, participant):
        """Build the slider page for this trial.

        The prompt asks the participant to match the network's ``target``
        word. ``experiment`` and ``participant`` are accepted but not used.
        """
        # NOTE(review): this value is never used below; the lookup is kept so
        # behaviour is unchanged -- consider removing it.
        _active_interval = INTERVALS[self.active_index]
        target = self.network.definition["target"]
        instruction = Markup(
            "Adjust the slider to match the following word as well as possible: "
            f"<strong>{target}</strong>"
        )
        return ChordSliderPage(
            label="chord_trial",
            prompt=instruction,
            starting_values=self.initial_vector,
            selected_idx=self.active_index,
            reverse_scale=self.reverse_scale,
            time_estimate=5,
            minimal_interactions=self.minimal_interactions,
            minimal_time=self.minimal_time,
            f0=self.definition["f0"],
            sliders=self.definition["sliders"],
        )

    def make_definition(self, experiment, participant):
        """Extend the parent definition with ``f0`` and ``sliders``.

        ``f0`` is drawn uniformly from 60 +/- 5; ``sliders`` is the
        module-level ``INTERVALS`` list.
        """
        trial_definition = super().make_definition(experiment, participant)
        trial_definition["f0"] = 60.0 + 10 * (random.random() - 0.5)
        trial_definition["sliders"] = INTERVALS
        return trial_definition

    def async_post_trial(self):
        """Log that the custom post-trial hook ran for this trial."""
        logger.info("Running custom async post trial (id = %i)", self.id)

class CustomNode(GibbsNode):
    """GibbsNode subclass used by both trial makers; adds no behaviour of its own."""
    __mapper_args__ = {"polymorphic_identity": "custom_node"}


class CustomSource(GibbsSource):
    """GibbsSource subclass used by both trial makers; adds no behaviour of its own."""
    __mapper_args__ = {"polymorphic_identity": "custom_source"}

class CustomTrialMaker(GibbsTrialMaker):
    """Gibbs trial maker with a custom feedback page and bonus rule."""

    give_end_feedback_passed = True
    performance_threshold = -1.0

    def get_end_feedback_passed_page(self, score):
        """Return an ``InfoPage`` reporting the consistency score as a percentage.

        A ``score`` of ``None`` is displayed as "NA".
        """
        if score is None:
            score_to_display = "NA"
        else:
            score_to_display = f"{(100 * score):.0f}"

        message = Markup(f"Your consistency score was <strong>{score_to_display}&#37;</strong>.")
        return InfoPage(message, time_estimate=5)

    def compute_bonus(self, score, passed):
        """Return the bonus for ``score``; ``passed`` is not used.

        Scores at or below 0.5 (chance rate) earn 0.0; above that the bonus
        grows linearly, reaching 1.0 at a score of 1.0. A missing score
        earns 0.0.
        """
        if score is None:
            return 0.0
        above_chance = 2 * (score - 0.5)  # chance-rate is 0.5
        return max(0.0, above_chance)


# Trial maker for the "training" phase. Chains are "within" participants:
# one chain per participant, 6 trials each, with no performance checks.
trial_maker_training = CustomTrialMaker(
    network_class=CustomNetwork,
    trial_class=CustomTrial,
    node_class=CustomNode,
    source_class=CustomSource,
    phase="training",  # free-form phase label
    time_estimate_per_trial=TIME_ESTIMATE_PER_TRIAL,
    chain_type="within",  # can be "within" or "across"
    num_trials_per_participant=6,
    num_nodes_per_chain=4, # note that the final node receives no trials
    num_chains_per_participant=1,  # set to None if chain_type="across"
    num_chains_per_experiment=None,  # set to None if chain_type="within"
    trials_per_node=1,
    active_balancing_across_chains=False,
    check_performance_at_end=False,
    check_performance_every_trial=False,
    propagate_failure=False,
    recruit_mode="num_participants",
    target_num_participants=0
)

# Trial maker for the main "experiment" phase. Chains are shared "across"
# participants: 50 chains for the whole experiment, 20 trials per
# participant, with a performance check at the end.
trial_maker_experiment = CustomTrialMaker(
    network_class=CustomNetwork,
    trial_class=CustomTrial,
    node_class=CustomNode,
    source_class=CustomSource,
    phase="experiment",  # free-form phase label
    time_estimate_per_trial=TIME_ESTIMATE_PER_TRIAL,
    chain_type="across",  # can be "within" or "across"
    num_trials_per_participant=20,
    # 41 nodes. NOTE(review): presumably 20 passes over the 2 intervals
    # plus the final node -- confirm.
    num_nodes_per_chain=2*20 + 1, # note that the final node receives no trials
    num_chains_per_participant=None,  # set to None if chain_type="across"
    num_chains_per_experiment=50,  # set to None if chain_type="within"
    trials_per_node=1,
    active_balancing_across_chains=False,
    check_performance_at_end=True,
    check_performance_every_trial=False,
    propagate_failure=False,
    recruit_mode="num_trials",
    target_num_participants=None,
    num_repeat_trials=6
)

##########################################################################################
#### Experiment
##########################################################################################

# Weird bug: if you instead import Experiment from psynet.experiment,
# Dallinger won't allow you to override the bonus method
# (or at least you can override it but it won't work).
class Exp(psynet.experiment.Experiment):
    """Experiment definition: the participant timeline plus the recruitment size.

    The timeline runs volume calibration and a headphone check, then the
    training phase, the main experiment phase, a final questionnaire and
    the end page.
    """

    # Elements are listed in the order they appear in the timeline.
    timeline = Timeline(
        VolumeCalibration(),
        headphone_check(),
        InfoPage("You passed the headphone screening task! Congratulations.", time_estimate=3),
        instructions_training,
        trial_maker_training,
        instructions_experiment,
        trial_maker_experiment,
        final_questionnaire,
        SuccessfulEndPage()
    )

    def __init__(self, session=None):
        """Initialise the base experiment and set the initial recruitment size.

        Args:
            session: Forwarded unchanged to the parent constructor.
        """
        super().__init__(session)

        # Change this if you want to simulate multiple simultaneous participants.
        self.initial_recruitment_size = INITIAL_RECRUITMENT_SIZE


# Instantiate the experiment without a session and expose its extra routes
# as a module-level name (presumably looked up by Dallinger when the
# experiment module is loaded -- TODO confirm).
extra_routes = Exp().extra_routes()
